from typing import Dict, List, Tuple, Optional
import pandas as pd
import os
import random
import time
import argparse
from multiprocessing import Pool
import numpy as np
from scipy.stats import spearmanr, kendalltau
from tqdm import tqdm


from elo_processor import DataProcessor, EloProcessor

def run_bootstrap_wrapper(args_tuple):
    """Adapter for ``Pool.imap``: expand one packed task into a ``run_bootstrap`` call.

    ``Pool.imap`` passes each task to the worker function as a single argument,
    so the caller packs the ``run_bootstrap`` arguments into a tuple in the
    positional order ``(df, valid_users, n_users, n_comp_per_user, model)``.

    Returns:
        The ``(results, time)`` pair produced by ``run_bootstrap``.
    """
    frame, users, user_count, comps_per_user, model_name = args_tuple
    return run_bootstrap(
        df=frame,
        valid_users=users,
        n_users=user_count,
        n_comp_per_user=comps_per_user,
        model=model_name,
    )



def run_bootstrap(df: pd.DataFrame, valid_users: List[str], n_users: Optional[int], n_comp_per_user: Optional[int], model: str):
    """Run one user-level bootstrap replicate and score it with ``EloProcessor``.

    Users are drawn with replacement from ``valid_users``. Each draw contributes
    that user's comparisons under a fresh pseudo-user id ``user_{i}``, so the
    same user drawn twice counts as two distinct raters.

    Args:
        df: Comparison table; must contain an ``answerer`` column.
        valid_users: Pool of answerer ids to resample from.
        n_users: Number of users to draw; ``None`` draws ``len(valid_users)``.
        n_comp_per_user: If given, keep at most this many comparisons per drawn
            user, chosen uniformly without replacement; ``None`` keeps all.
        model: Model name forwarded to ``EloProcessor.process``.

    Returns:
        The two values unpacked from ``EloProcessor.process`` (presumably the
        ranking DataFrame and its runtime -- confirm against elo_processor).
    """
    draw_count = len(valid_users) if n_users is None else n_users
    drawn_users = random.choices(valid_users, k=draw_count)

    # Group once instead of scanning the whole frame for every drawn user.
    rows_by_user = {user: rows for user, rows in df.groupby('answerer', sort=False)}

    frames = []
    for i, user in enumerate(drawn_users):
        user_rows = rows_by_user.get(user)
        if user_rows is None:
            # User has no comparisons; contributes nothing (as before).
            continue
        if n_comp_per_user is not None and len(user_rows) > n_comp_per_user:
            # Sample positions with the `random` module (reseeded per forked
            # worker) and keep them sorted so the original comparison order,
            # which a sequential Elo update may depend on, is preserved.
            keep = sorted(random.sample(range(len(user_rows)), n_comp_per_user))
            user_rows = user_rows.iloc[keep]
        frames.append(user_rows.assign(answerer=f"user_{i}"))

    # Concatenate once (the old per-iteration concat was quadratic). When
    # nothing was sampled, fall back to an empty frame that keeps df's columns.
    df_bootstrap = pd.concat(frames) if frames else df.iloc[0:0].copy()

    pseudo_users = df_bootstrap['answerer'].unique()
    processor = EloProcessor(df_bootstrap, pseudo_users)
    results, elapsed = processor.process(df=df_bootstrap, valid_users=pseudo_users, model=model)
    return results, elapsed

def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and return validated arguments."""
    parser = argparse.ArgumentParser(description='Run bootstrap analysis for ELO ratings')
    parser.add_argument('--project_name', type=str, required=True, help='Name of the project')
    parser.add_argument('--csv_name', type=str, required=True, help='Name of the CSV file to process')
    parser.add_argument('--model', type=str, default='google_elo', help='Model to use for analysis')
    parser.add_argument('--n_users', type=int, default=None, help='Number of users to sample (default: all available users)')
    parser.add_argument('--n_comp_per_user', type=int, default=None, help='Number of comparisons per user to sample (default: all available)')
    parser.add_argument('--bootstrap_n', type=int, required=True, help='Number of bootstraps')
    parser.add_argument('--main_csv_name', type=str, default=None, help='Name of the CSV file to process for main run')
    parser.add_argument('--results_path', type=str, default=None, help='Base directory for results (default: results); output goes under {results_path}/{project_name}/{csv_name}/{model}/')
    parser.add_argument('--n_workers', type=int, default=10, help='Number of worker processes for the bootstrap (default: 10)')
    args = parser.parse_args()

    # Reject values that would otherwise surface later as an empty Pool,
    # an empty resample, or a mean over zero iterations.
    if args.bootstrap_n < 1:
        parser.error('--bootstrap_n must be at least 1')
    for name in ('n_users', 'n_comp_per_user'):
        value = getattr(args, name)
        if value is not None and value < 1:
            parser.error(f'--{name} must be at least 1')
    if args.n_workers < 1:
        parser.error('--n_workers must be at least 1')
    return args


def _agreement_metrics(results, times, results_main):
    """Compare each bootstrap ranking with the main ranking.

    Returns four equal-length lists -- top-1 agreement, Spearman rho, Kendall
    tau and runtime -- with one entry per non-empty bootstrap result, so the
    runtime column stays aligned with the metric columns.
    """
    # Assumes EloProcessor returns rows ordered best-first, so row 0 is the
    # top method -- TODO confirm against elo_processor.
    main_top = results_main['Method'].iloc[0]
    main_scores = results_main[['Method', 'ELO Score']]

    top1, rho, tau, kept_times = [], [], [], []
    for result, elapsed in tqdm(zip(results, times), total=len(results), desc="Calculating correlations"):
        if result.empty:
            continue
        top1.append(result['Method'].iloc[0] == main_top)

        # Pair scores by method name; methods missing from either side drop out.
        paired = result[['Method', 'ELO Score']].merge(main_scores, on='Method', suffixes=('_boot', '_main'))
        if len(paired) > 1:  # need at least 2 points for a correlation
            rho.append(spearmanr(paired['ELO Score_boot'], paired['ELO Score_main']).correlation)
            tau.append(kendalltau(paired['ELO Score_boot'], paired['ELO Score_main']).correlation)
        else:
            rho.append(np.nan)
            tau.append(np.nan)
        kept_times.append(elapsed)
    return top1, rho, tau, kept_times


def _results_file_path(args: argparse.Namespace) -> str:
    """Return <base>/<project>/<csv stem>/<model>/<n_users|all>/<n_comp|all>/bootstrap_results.csv."""
    base = args.results_path if args.results_path is not None else 'results'
    # Strip only a trailing ".csv"; str.replace would also drop inner occurrences.
    csv_stem = args.csv_name[:-len('.csv')] if args.csv_name.endswith('.csv') else args.csv_name
    users_dir = 'all' if args.n_users is None else str(args.n_users)
    comps_dir = 'all' if args.n_comp_per_user is None else str(args.n_comp_per_user)
    return os.path.join(base, args.project_name, csv_stem, args.model, users_dir, comps_dir, 'bootstrap_results.csv')


def main():
    """CLI entry point: bootstrap ELO rankings and measure agreement with a reference run.

    Loads projects/<project_name>/data/<csv_name> and runs ``--bootstrap_n``
    bootstrap replicates of ``run_bootstrap`` in a process pool. It scores a
    reference ranking on ``--main_csv_name`` (or the same data) and writes
    per-replicate top-1 agreement, Spearman rho, Kendall tau and runtime to a CSV.
    """
    args = _parse_args()
    if args.n_users is not None:
        print(f"Running bootstrap for {args.project_name} with data {args.csv_name}, model {args.model} and {args.n_users} users")
    else:
        print(f"Running bootstrap for {args.project_name} with data {args.csv_name}, model {args.model} and all available users")
    if args.n_comp_per_user is not None:
        print(f"Sampling {args.n_comp_per_user} comparisons per user")
    else:
        print("Sampling all available comparisons per user")

    data_path = os.path.join('projects', args.project_name, 'data', args.csv_name)

    # Load and process data
    data_processor = DataProcessor(data_path)
    if not data_processor.load_data():
        print("Failed to load data")
        return

    df = data_processor.df
    valid_users = data_processor.get_valid_users()

    # No point starting more workers than there are replicates.
    n_workers = min(args.n_workers, args.bootstrap_n)
    print(f"Starting {args.bootstrap_n} bootstrap iterations using {n_workers} CPU cores...")

    # Every task carries the same DataFrame. Batching tasks lets pickle
    # serialize it once per chunk instead of once per replicate.
    chunksize = max(1, args.bootstrap_n // (n_workers * 4))
    task = (df, valid_users, args.n_users, args.n_comp_per_user, args.model)
    with Pool(processes=n_workers) as pool:
        outputs = list(tqdm(pool.imap(run_bootstrap_wrapper, [task] * args.bootstrap_n, chunksize=chunksize),
                            total=args.bootstrap_n,
                            desc="Bootstrap iterations"))
    boot_results = [output[0] for output in outputs]
    boot_times = [output[1] for output in outputs]

    # Reference ranking: either a separate "main" CSV or the full input data.
    if args.main_csv_name is not None:
        data_processor_main = DataProcessor(os.path.join('projects', args.project_name, 'data', args.main_csv_name))
        if not data_processor_main.load_data():
            print("Failed to load data")
            return
        df_main = data_processor_main.df
        valid_users_main = data_processor_main.get_valid_users()
    else:
        df_main = df
        valid_users_main = valid_users
    print(len(df_main), len(valid_users_main))
    processor_main = EloProcessor(df_main, valid_users_main)
    results_main, _ = processor_main.process(df=df_main, valid_users=valid_users_main, model=args.model)
    print(results_main)
    if results_main.empty:
        print("Main run produced no ranking; cannot compute agreement")
        return

    top1_agreement, spearman_correlation, kendall_tau, times = _agreement_metrics(boot_results, boot_times, results_main)

    # nanmean: a replicate with < 2 shared methods (or constant scores) yields
    # NaN and should not turn the whole summary into NaN.
    print(f"Top1 agreement: {np.mean(top1_agreement):.3f}")
    print(f"Spearman correlation: {np.nanmean(spearman_correlation):.6f}")
    print(f"Kendall's tau: {np.nanmean(kendall_tau):.6f}")

    # save the results to a csv
    results_df = pd.DataFrame({
        'Top1 agreement': top1_agreement,
        'Spearman correlation': spearman_correlation,
        "Kendall's tau": kendall_tau,
        'Time': times
    })

    results_path = _results_file_path(args)
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    results_df.to_csv(results_path, index=False)
    print(f"Results saved to {results_path}")


if __name__ == "__main__":
    main()